{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf0022a3-6769-4a28-a163-d5408344b23d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from statistics import median\n",
    "\n",
    "import anndata as ad\n",
    "import scHash\n",
    "from sklearn.metrics import f1_score, precision_score, recall_score\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "861967af-edc6-47be-8a2f-f66c6c5ebbc1",
   "metadata": {},
   "source": [
    "# Tutorial for atlas-level cell annotation\n",
    "This tutorial demonstrates atlas-level cell annotation on the Tabula Muris Senis dataset, which can be downloaded [here](https://figshare.com/projects/Tabula_Muris_Senis/64982). We followed scArches' preprocessing pipeline, and the preprocessed data can be downloaded [here](https://drive.google.com/file/d/1lfDu-TGsUvHrmXoSWkj0tptvWNYFgs2x/view?usp=share_link). The dataset contains 356213 cells and 5000 highly variable genes, with cell type, method, and tissue annotations. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67da14ca-0fb5-4b5d-8599-1cd4384044cb",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf3f0bd2-3658-44d7-8ea4-c0e13e5309ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to the preprocessed Tabula Muris Senis AnnData file (.h5ad).\n",
    "data_dir = '../../../../share_data/Tabula_Muris_Senis(TM)/tabula_senis_normalized_all_hvg.h5ad'\n",
    "\n",
    "# This single file is the source of both the reference and the query cells.\n",
    "data = ad.read_h5ad(data_dir)\n",
    "\n",
    "# Split row positions 80/20 into reference and query sets. The split is\n",
    "# stratified on the cell type label, and random_state pins it so that\n",
    "# Restart & Run All reproduces the same partition.\n",
    "reference_indicies, query_indicies = train_test_split(\n",
    "    list(range(data.n_obs)),\n",
    "    train_size=0.8,\n",
    "    stratify=data.obs['cell_ontology_class'],\n",
    "    random_state=42,\n",
    ")\n",
    "\n",
    "# Subset the AnnData object by row position.\n",
    "train = data[reference_indicies]\n",
    "test = data[query_indicies]\n",
    "\n",
    "# Sanity check: reference and query together cover every cell exactly once.\n",
    "assert train.n_obs + test.n_obs == data.n_obs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b6cb2c4-7d29-4d06-ad8e-4ed36e2829d2",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Training Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85ade212-ca91-43c7-8182-9c4cbf248e58",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Build the scHash training datamodule from the reference cells.\n",
    "datamodule = scHash.setup_training_data(\n",
    "    train_data=train,\n",
    "    cell_type_key='cell_ontology_class',\n",
    "    batch_key='tissue',\n",
    "    batch_size=512,\n",
    ")\n",
    "\n",
    "# Directory passed to scHash.training as the checkpoint location.\n",
    "checkpointPath = '../checkpoint/'\n",
    "\n",
    "# Initialize the scHash model from the datamodule, then train it.\n",
    "model = scHash.scHashModel(datamodule)\n",
    "trainer, best_model_path, training_time = scHash.training(\n",
    "    model=model,\n",
    "    datamodule=datamodule,\n",
    "    checkpointPath=checkpointPath,\n",
    "    max_epochs=200,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4b826ab-1ffb-494e-8037-9a02a6407e3d",
   "metadata": {},
   "source": [
    "## Test Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdecaada-8e97-46b5-94c9-6807c734061f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the held-out query cells to the existing datamodule.\n",
    "datamodule.setup_test_data(test)\n",
    "\n",
    "# Run scHash testing with the trainer, the model and the best checkpoint path.\n",
    "pred_labels, hash_codes = scHash.testing(trainer, model, best_model_path)\n",
    "\n",
    "# Score the predictions: one F1 value per class (average=None),\n",
    "# summarised by the median and rounded to three decimals.\n",
    "labels_true = test.obs.cell_ontology_class\n",
    "per_class_f1 = f1_score(labels_true, pred_labels, average=None)\n",
    "f1_median = round(median(per_class_f1), 3)\n",
    "\n",
    "print(f'F1 Median: {f1_median}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jupyter_py3",
   "language": "python",
   "name": "jupyter_py3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
